Add TTFT benchmarks + update sparsity benchmarks #1140
Conversation
Summary: This PR adds a sparsity option to the LLaMa benchmarks.
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1140
Note: links to docs will display an error until the docs builds have completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit de2d447 with merge base 2f97b09:
NEW FAILURE - the following job has failed.
BROKEN TRUNK - the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchao/_models/llama/generate.py (Outdated)

```python
from torchao.dtypes import MarlinSparseLayout
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
if sparsity and "semi" in sparsity:
    quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
```
This isn't using any of the derived variables. It should either use the derived ones or be moved to a separate section.
lgtm if you move the marlin stuff so it's clearer which derived variables it actually uses
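A rough sketch of the kind of restructuring being asked for, assuming the surrounding code derives values such as `group_size` and `use_hqq` from the quantization string (those names and the helper below are illustrative, not the code from this PR):

```python
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayout

def apply_int4_quant(model, quantization: str, sparsity: str | None,
                     group_size: int, use_hqq: bool):
    """Hypothetical helper showing the separation of the two paths."""
    if "int4wo" not in quantization:
        return
    if sparsity and "semi" in sparsity:
        # The Marlin sparse path picks its own layout and does not use the
        # derived group_size / use_hqq values, so it sits in its own branch.
        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
    else:
        # The dense int4 path is the one that actually consumes the derived variables.
        quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
```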
torchao/_models/llama/generate.py
Outdated
print(f"Peak Memory Usage: {mem:.02f} GB") | ||
print(f"Model Size: {model_size:.02f} GB") | ||
if write_result: | ||
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " | ||
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " | ||
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, time={t:5.4f} sec, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " |
"time" is a really generic term: is this TTFT or the overall run? The tok/s info is already the non-prefill indicator, so TTFT (time to do prefill) is probably more valuable.
It's the overall time, but I limit num_tokens to 1. I can make this a bit clearer though, maybe with a --ttft flag that forces num_tokens_generated to be 1.
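A minimal sketch of what such a flag could look like (argument names assumed for illustration, not the final implementation in this PR):

```python
import argparse

parser = argparse.ArgumentParser(description="LLaMa benchmark runner (sketch)")
parser.add_argument("--max_new_tokens", type=int, default=200,
                    help="Number of new tokens to generate.")
parser.add_argument("--ttft", action="store_true",
                    help="Benchmark time-to-first-token by generating a single token.")
args = parser.parse_args()

if args.ttft:
    # With exactly one generated token, the measured end-to-end time is
    # dominated by prefill, i.e. it approximates time to first token.
    args.max_new_tokens = 1
```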
Force-pushed from 6858180 to 4fdfa7b (compare).
This PR adds TTFT (time-to-first-token) benchmarks to torchAO, and also updates the benchmarking script to handle sparsity a bit more cleanly and to use the 2:4 sparse checkpoints that are available. Additionally, it adds padding support for int8 dynamic quant + 2:4 sparsity, which was missing before.
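A hedged sketch of the padding idea for the int8 dynamic quant + 2:4 sparse path: if a linear weight's dimensions aren't aligned to what the semi-structured sparse kernels expect, pad them before sparsifying. The helper name and the alignment value of 64 below are assumptions for illustration, not the code added in this PR.

```python
import torch
import torch.nn.functional as F

def pad_weight_for_24_sparsity(weight: torch.Tensor, multiple: int = 64) -> torch.Tensor:
    """Pad a 2D linear weight so both dimensions are multiples of `multiple`.

    Semi-structured (2:4) sparse kernels generally require aligned shapes;
    64 is an assumed alignment here, not a value taken from this PR. Callers
    must pad activations / slice outputs to match the padded shapes.
    """
    out_features, in_features = weight.shape
    pad_out = (-out_features) % multiple
    pad_in = (-in_features) % multiple
    if pad_out == 0 and pad_in == 0:
        return weight
    # F.pad pads from the last dimension backwards: (left, right, top, bottom).
    return F.pad(weight, (0, pad_in, 0, pad_out))
```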